from simple_block import simple, exog
from solved_block import solved

def phillips_curve(J):
    """Wrap ``J`` with ``exog`` (imported from simple_block) and hand back the result."""
    wrapped = exog(J)
    return wrapped

# %% Standard New Keynesian (NK) model

@simple
def standard_nk(y, n, w, i, pi, sig, phi, phi_pi, phi_y, eps_i):
    """Equilibrium residuals of the simple NK model (no inflation equation here).

    ``x(1)`` / ``x(-1)`` are read as the next-/previous-period value of ``x``
    (assumed lead/lag notation of the ``@simple`` framework -- confirm in
    simple_block). Each ``*_res`` output is a residual listed in
    ``targets_nk``. Inflation ``pi`` is an input only; nothing in this block
    determines it.
    """
    # Euler equation: y = y(+1) - (1/sig) * (i - pi(+1))
    euler_res = y - (y(1) - (1 / sig) * (i - pi(1)))
    # Labor supply: real wage w = phi * n + sig * y
    labor_res = phi * n + sig * y - w
    # Production: y = n
    prod_res = y - n
    # Interest-rate rule: i = phi_pi * pi + phi_y * y + eps_i (eps_i is an additive shock)
    taylor_res = i - (phi_pi * pi + phi_y * y + eps_i)
    return euler_res, labor_res, prod_res, taylor_res


@simple
def pc_inputs_nk(w):
    """Phillips-curve input for the NK model: the gap is the real wage ``w``.

    NOTE(review): presumably consumed by an externally supplied Phillips
    curve (cf. ``phillips_curve``), since ``standard_nk`` has no inflation
    equation -- confirm.
    """
    gap = w
    return gap


# Equilibrium system of the NK model: four unknowns paired with four residual targets.
unknowns_nk = [
    'n',
    'y',
    'w',
    'i',
]
targets_nk = [
    'euler_res',
    'labor_res',
    'prod_res',
    'taylor_res',
]
blocks_nk = [
    standard_nk,
    pc_inputs_nk,
]

# Steady-state point: variables at zero plus the parameter calibration.
ss_nk = {
    # variables
    'y': 0,
    'n': 0,
    'w': 0,
    'i': 0,
    'pi': 0,
    # parameters
    'sig': 1,       # 1 / sig multiplies the real rate in the Euler equation
    'phi': 5,       # coefficient on n in the labor-supply relation
    'phi_pi': 1.5,  # rule response to inflation
    'phi_y': 0,     # rule response to output
    # shocks
    'eps_i': 0,     # additive shock to the interest-rate rule
}


# %% Smets-Wouters model

@solved(unknowns=['c'], targets=['euler_res'])
def household(c, n, r, eps_b, c1, c2, c3):
    """Consumption Euler equation with habit (sticky-price economy), residual form.

        c = c1*c(-1) + (1-c1)*c(+1) + c2*(n - n(+1)) - c3*r + eps_b

    ``r`` is the real rate (``geneq`` imposes i = r + pi(+1)); ``eps_b`` is
    the risk-premium shock; ``c1``-``c3`` come from ``par_to_coefs`` and
    depend on the habit parameter ``la_c``. The decorator declares ``c`` as
    the unknown paired with the ``euler_res`` target.
    """
    euler_res = c1*c(-1) + (1-c1)*c(1) + c2*(n - n(1)) - c3*r + eps_b - c
    return euler_res


@solved(unknowns=['cf'], targets=['euler_res_f'])
def household_f(cf, nf, rf, eps_b, c1, c2, c3):
    """Flexible-price ("f") counterpart of ``household``, residual form.

        cf = c1*cf(-1) + (1-c1)*cf(+1) + c2*(nf - nf(+1)) - c3*rf + eps_b

    Same coefficients and risk-premium shock as the sticky-price economy.
    The decorator declares ``cf`` as the unknown paired with ``euler_res_f``.
    """
    euler_res_f = c1*cf(-1) + (1-c1)*cf(1) + c2*(nf - nf(1)) - c3*rf + eps_b - cf
    return euler_res_f


@simple
def household_markup(c, cf, n, nf, w, wf, la_c, ga, sig_l):
    """Wage markups of the sticky-price and flexible-price economies.

    Each markup is the real wage minus
        sig_l*n + (c - (la_c/ga)*c(-1)) / (1 - la_c/ga),
    the marginal rate of substitution with habit ``la_c`` scaled by gross
    growth ``ga``. ``mu_w_f`` is listed in ``targets_sw``; ``mu_w`` is an
    input of the wage Phillips curve ``nkpc_w``.
    """
    mu_w = w - (sig_l*n + (c-c(-1)*la_c/ga)/(1-la_c/ga))
    mu_w_f = wf - (sig_l*nf + (cf-cf(-1)*la_c/ga)/(1-la_c/ga))
    return mu_w, mu_w_f


@solved(unknowns=['I', 'q', 'rk', 'k'], targets=['I_res', 'q_res', 'rk_res', 'k_res'])
def firm(I, q, k, rk, n, w, r, eps_I, eps_b, eps_a, I1, I2, q1, z1, k1, k2, c3, fixed_cost, alpha):
    """Investment, capital and production side of the sticky-price economy.

    The decorator declares ``I`` (investment), ``q`` (Tobin's q), ``rk``
    (rental rate of capital) and ``k`` (capital stock) as unknowns, paired
    with the four ``*_res`` targets. The block also outputs capital
    utilization ``z``, capital services ``ks``, output ``y``, the price
    markup ``mu_p`` and real marginal cost ``mc`` (= -mu_p).
    ``I1``, ``I2``, ``q1``, ``z1``, ``k1``, ``k2`` and ``c3`` come from
    ``par_to_coefs``; ``eps_I``, ``eps_b`` and ``eps_a`` are the investment,
    risk-premium and productivity shocks.
    """
    I_res = I1 * I(-1) + (1 - I1) * I(1) + I2 * q + eps_I - I  # Investment
    q_res = q1 * q(1) + (1 - q1) * rk(1) - r + eps_b / c3 - q  # Tobin's q (r is the real rate)
    z = z1 * rk  # Capital utilization
    ks = z + k(-1)  # Capital services: utilization plus last period's capital stock
    y = (1 + fixed_cost) * (alpha * ks + (1 - alpha) * n + eps_a)  # Production function
    rk_res = - (ks - n) + w - rk  # Rental rate
    k_res = k1 * k(-1) + (1 - k1) * I + k2 * eps_I - k  # Law of motion of capital
    mu_p = alpha * (ks - n) + eps_a - w  # Price markup
    mc = -mu_p  # Real marginal cost
    return I_res, q_res, rk_res, k_res, ks, z, y, mu_p, mc


@solved(unknowns=['If', 'qf', 'rkf', 'kf'], targets=['I_res_f', 'q_res_f', 'rk_res_f', 'k_res_f'])
def firm_f(If, qf, kf, rkf, nf, wf, rf, eps_I, eps_b, eps_a, I1, I2, q1, z1, k1, k2, c3, fixed_cost, alpha):
    """Flexible-price ("f") counterpart of ``firm``.

    Same equations and coefficients as ``firm`` applied to the "f"
    variables. It outputs ``ksf``, ``zf``, ``yf`` and the markup
    ``mu_p_f``, which ``targets_sw`` lists as a target. Unlike ``firm``,
    there is no marginal-cost output.
    """
    I_res_f = I1 * If(-1) + (1 - I1) * If(1) + I2 * qf + eps_I - If  # Investment
    q_res_f = q1 * qf(1) + (1 - q1) * rkf(1) - rf + eps_b / c3 - qf  # Tobin's q
    zf = z1 * rkf  # Capital utilization
    ksf = zf + kf(-1)  # Capital services: utilization plus last period's capital stock
    yf = (1 + fixed_cost) * (alpha * ksf + (1 - alpha) * nf + eps_a)  # Production function
    rk_res_f = - (ksf - nf) + wf - rkf  # Rental rate
    k_res_f = k1 * kf(-1) + (1 - k1) * If + k2 * eps_I - kf  # Law of motion of capital
    mu_p_f = alpha * (ksf - nf) + eps_a - wf  # Price markup
    return I_res_f, q_res_f, rk_res_f, k_res_f, ksf, zf, yf, mu_p_f


@solved(unknowns=['pi'], targets=['pi_res'])
def nkpc_p(pi, mu_p, eps_p, pi1, pi2, pi3):
    """Price Phillips curve with indexation, residual form.

        pi = pi1*pi(-1) + pi2*pi(+1) - pi3*mu_p + eps_p

    ``pi1``-``pi3`` come from ``par_to_coefs`` (price indexation ``iota_p``,
    price rigidity ``zeta_p``); ``eps_p`` is the price-markup shock.

    NOTE(review): this block is not listed in ``blocks_sw``; inflation
    appears to be left to an external Phillips curve fed by
    ``pc_inputs_sw`` -- confirm that this is intentional.
    """
    pi_res = pi1*pi(-1) + pi2*pi(1) - pi3*mu_p + eps_p - pi
    return pi_res


@solved(unknowns=['wout'], targets=['w_res'])
def nkpc_w(pi, wout, mu_w, eps_w, w1, w2, w3, w4):
    """Wage Phillips curve with indexation, residual form.

        wout = w1*wout(-1) + (1-w1)*(wout(+1) + pi(+1)) - w2*pi
               + w3*pi(-1) - w4*mu_w + eps_w

    ``wout`` is the real wage implied by this curve; ``geneq`` ties it to
    the real wage ``w``. ``w1``-``w4`` come from ``par_to_coefs``; ``eps_w``
    is the wage-markup shock.

    NOTE(review): the internal target here is named ``w_res``, the same
    name as ``geneq``'s ``w_res = w - wout``; they are different quantities.
    """
    w_res = w1*wout(-1) + (1-w1)*(wout(1) + pi(1)) - w2*pi + w3*pi(-1) - w4*mu_w + eps_w - wout
    return w_res


@solved(unknowns=['i'], targets=['taylor_res'])
def monetary(i, pi, y, yf, eps_i, rho, psi1, psi2, psi3):
    """Taylor rule with interest-rate smoothing, residual form.

        i = rho*i(-1) + (1-rho)*(psi1*pi + psi2*(y - yf))
            + psi3*((y - yf) - (y(-1) - yf(-1))) + eps_i

    ``y - yf`` is the output gap relative to the flexible-price economy and
    ``psi3`` responds to its change; ``eps_i`` is the monetary shock.
    """
    taylor_res = rho*i(-1) + (1-rho)*(psi1*pi + psi2*(y - yf)) + psi3*( y - y(-1) - yf + yf(-1) ) + eps_i - i
    return taylor_res


@solved(unknowns=['p', 'W'], targets=['p_res', 'W_res'])
def nominal(p, W, pi, w):
    """Price level and nominal wage.

    ``p`` accumulates inflation (p = p(-1) + pi), and the nominal wage is
    W = w + p (real wage plus price level).
    """
    p_res = p - p(-1) - pi
    W_res = W - w - p
    return p_res, W_res


@simple
def geneq(y, c, I, z, i, r, pi, w, wout, eps_g, cy, Iy, zy):
    """Market-clearing and consistency residuals of the sticky-price economy.

    goods_mkt : output y equals cy*c + Iy*I + zy*z + eps_g (eps_g is the
                government-spending shock).
    fisher    : nominal rate i equals real rate r plus pi(+1).
    w_res     : real wage w equals wout from the wage Phillips curve.
    All three are listed in ``targets_sw``.
    """
    goods_mkt = cy*c + Iy*I + zy*z + eps_g - y
    fisher = i - (r+pi(1))
    w_res = w - wout
    return goods_mkt, fisher, w_res


@simple
def geneq_f(yf, cf, If, zf, eps_g, cy, Iy, zy):
    """Goods-market clearing of the flexible-price economy.

    Residual of yf = cy*cf + Iy*If + zy*zf + eps_g, with the same output
    shares as ``geneq``.
    """
    goods_mkt_f = cy*cf + Iy*If + zy*zf + eps_g - yf
    return goods_mkt_f


@simple
def measurement(y, c, I, n, p, W, w, i, cy, Iy):
    """Observable ("_dta") series built from model variables.

    Consumption and investment are scaled by their steady-state output
    shares ``cy`` and ``Iy``; every other series equals its model variable.
    """
    y_dta = y
    c_dta = cy * c  # consumption weighted by its output share
    I_dta = Iy * I  # investment weighted by its output share
    n_dta = n
    p_dta = p
    W_dta = W
    w_dta = w
    i_dta = i
    return y_dta, c_dta, I_dta, n_dta, p_dta, W_dta, i_dta, w_dta


@simple
def pc_inputs_sw(mu_p, fixed_cost, curvp):
    """Phillips-curve input for the Smets-Wouters model.

    gap = -mu_p / (fixed_cost*curvp + 1): real marginal cost (``firm`` sets
    mc = -mu_p) divided by the same Kimball factor that appears in the
    denominator of ``pi3`` in ``par_to_coefs``.
    NOTE(review): presumably feeds an external Phillips curve in place of
    ``nkpc_p`` -- confirm.
    """
    gap = -mu_p / (fixed_cost * curvp + 1)
    return gap


def par_to_coefs(par):
    """Map structural parameters to the coefficients of the log-linear equations.

    Copies ``par`` and adds to the copy:

    (i) zero values for all shocks;
    (ii) steady-state ratios and the composite coefficients used by the
         Smets-Wouters blocks (``c1``-``c3``, ``I1``, ``I2``, ``q1``, ``z1``,
         ``k1``, ``k2``, ``pi1``-``pi3``, ``w1``-``w4``, ``ibar``, ...);
    (iii) values for the endogenous variables. All are 0 except the nominal
          rate ``i``, which is set to ``ibar``.

    Parameters
    ----------
    par : Mapping
        Structural parameters. Must contain ``cbeta``, ``ybar``, ``sig_c``,
        ``delta``, ``alpha``, ``fixed_cost``, ``g``, ``mu_w_ss``, ``la_c``,
        ``phi``, ``psi``, ``iota_p``, ``zeta_p``, ``curvp``, ``iota_w``,
        ``zeta_w``, ``curvw`` and ``pibar``. It is not modified.

    Returns
    -------
    dict
        New dictionary holding every entry of ``par`` plus the derived values.

    Raises
    ------
    KeyError
        If one of the required structural parameters is missing.
    """
    ss = dict(par)  # shallow copy: the caller's mapping is left untouched

    # Shocks: every exogenous disturbance is zero at this point.
    for shock in ('eps_a',   # productivity shock
                  'eps_b',   # risk premium shock
                  'eps_i',   # monetary shock
                  'eps_g',   # government spending shock
                  'eps_p',   # price markup shock
                  'eps_w',   # wage markup shock
                  'eps_I'):  # investment shock
        ss[shock] = 0

    # Discounting, growth and steady-state ratios
    ss['beta'] = 1 / (ss['cbeta'] / 100 + 1)  # Discount factor
    ss['ga'] = 1 + ss['ybar'] / 100  # Steady state (gross) growth rate
    ss['betabar'] = ss['beta'] * ss['ga'] ** (1 - ss['sig_c'])  # Growth-adjusted discount factor
    ss['Rkss'] = ss['ga'] ** ss['sig_c'] / ss['beta'] - (1 - ss['delta'])  # Steady state rental rate of capital
    ss['Ik'] = (1 - (1 - ss['delta']) / ss['ga']) * ss['ga']  # Investment over capital
    ss['wss'] = (ss['alpha'] ** ss['alpha'] * (1 - ss['alpha']) ** (1 - ss['alpha']) /
                 ((1 + ss['fixed_cost']) * ss['Rkss'] ** ss['alpha'])) ** (1 / (1 - ss['alpha']))  # Steady state real wage
    ss['nk'] = ((1 - ss['alpha']) / ss['alpha']) * (ss['Rkss'] / ss['wss'])  # n over k ratio
    ss['ky'] = (1 + ss['fixed_cost']) * (ss['nk']) ** (ss['alpha'] - 1)  # Capital over gdp ratio
    ss['Iy'] = ss['Ik'] * ss['ky']  # Investment over gdp
    ss['cy'] = 1 - ss['g'] - ss['Iy']  # Consumption over gdp
    ss['zy'] = ss['Rkss'] * ss['ky']  # z over gdp
    ss['whlc'] = (1 / ss['mu_w_ss']) * (1 - ss['alpha']) / ss['alpha'] * ss['Rkss'] * ss['ky'] / ss['cy']  # wl over c ratio

    # Consumption Euler equation (household, household_f)
    ss['c1'] = (ss['la_c'] / ss['ga']) / (1 + ss['la_c'] / ss['ga'])
    ss['c2'] = ((ss['sig_c'] - 1) * ss['whlc']) / (ss['sig_c'] * (1 + ss['la_c'] / ss['ga']))
    ss['c3'] = (1 - ss['la_c'] / ss['ga']) / ((1 + ss['la_c'] / ss['ga']) * ss['sig_c'])

    # Investment, Tobin's q, utilization and capital accumulation (firm, firm_f)
    ss['I1'] = 1 / (1 + ss['betabar'])
    ss['I2'] = ss['I1'] / (ss['ga'] ** 2 * ss['phi'])
    ss['q1'] = ss['betabar'] * (1 - ss['delta']) / ss['ga']
    ss['z1'] = (1 - ss['psi']) / ss['psi']
    ss['k1'] = (1 - ss['delta']) / ss['ga']
    ss['k2'] = (1 - (1 - ss['delta']) / ss['ga']) * ss['ga'] ** 2 * ss['phi']

    # Price Phillips curve (nkpc_p)
    ss['pi1'] = ss['iota_p'] / (1 + ss['betabar'] * ss['iota_p'])
    ss['pi2'] = ss['betabar'] / (1 + ss['betabar'] * ss['iota_p'])
    ss['pi3'] = ((1 - ss['betabar'] * ss['zeta_p']) * (1 - ss['zeta_p']) /
                 (ss['zeta_p'] * (ss['fixed_cost'] * ss['curvp'] + 1) * (1 + ss['betabar'] * ss['iota_p'])))

    # Wage Phillips curve (nkpc_w)
    ss['w1'] = 1 / (1 + ss['betabar'])
    ss['w2'] = (1 + ss['betabar'] * ss['iota_w']) * ss['w1']
    ss['w3'] = ss['iota_w'] * ss['w1']
    ss['w4'] = (1 - ss['betabar'] * ss['zeta_w']) * (1 - ss['zeta_w']) / (
                (1 + ss['betabar']) * ss['zeta_w'] * ((ss['mu_w_ss'] - 1) * ss['curvw'] + 1))

    # Steady-state nominal interest rate, in percent (same units as pibar)
    ss['ibar'] = 100 * ((1 + ss['pibar'] / 100) / (ss['beta'] * ss['ga'] ** (-ss['sig_c'])) - 1)

    # Endogenous variables. The equations are log-linear, so variables are set
    # to 0. Each line pairs a sticky-price variable with its flexible-price
    # ("f") counterpart where one exists.
    ss['I'] = ss['If'] = 0
    ss['q'] = ss['qf'] = 0
    ss['k'] = ss['kf'] = 0
    ss['rk'] = ss['rkf'] = 0
    ss['r'] = ss['rf'] = 0
    # NOTE(review): unlike every other variable, ``i`` is set to the level
    # ``ibar`` rather than 0. With r = pi = 0, the ``fisher`` (geneq) and
    # ``taylor_res`` (monetary) residuals are therefore non-zero at this point
    # unless ibar == 0. Kept unchanged -- confirm whether this is intended.
    ss['i'] = ss['ibar']
    ss['n'] = ss['nf'] = 0
    ss['w'] = ss['wf'] = ss['wout'] = 0
    ss['c'] = ss['cf'] = 0
    ss['y'] = ss['yf'] = 0
    ss['pi'] = 0  # no flexible-price inflation variable is used by the blocks here
    ss['mu_p'] = ss['mu_p_f'] = 0
    ss['mu_w'] = ss['mu_w_f'] = 0
    ss['z'] = ss['zf'] = 0
    ss['p'] = 0
    ss['W'] = 0

    return ss


parameters = {
    # --- estimated structural parameters ---
    'phi': 5.74,         # capital adjustment cost
    'sig_c': 1.38,       # inverse intertemporal elasticity of substitution (c3 scales with 1 / sig_c)
    'la_c': 0.71,        # consumption habit
    'zeta_w': 0.70,      # wage rigidity
    'sig_l': 1.83,       # labor supply parameter (weight on n in the wage markup)
    'zeta_p': 0.66,      # price rigidity
    'iota_w': 0.58,      # degree of wage indexation
    'iota_p': 0.24,      # degree of price indexation
    'psi': 0.54,         # capital utilization adjustment cost elasticity
    'fixed_cost': 0.60,  # fixed cost of production
    'psi1': 2.04,        # Taylor rule: coefficient on inflation
    'rho': 0.81,         # Taylor rule: inertia
    'psi2': 0.08,        # Taylor rule: coefficient on the output gap
    'psi3': 0.22,        # Taylor rule: coefficient on the change in the output gap
    'pibar': 0.78,       # average quarterly inflation rate, in percent
    'cbeta': 0.16,       # discount-factor parameter: beta = 1 / (1 + cbeta / 100)
    'nbar': 0.53,        # NOTE(review): not used in the code shown; presumably steady-state hours
    'ybar': 0.43,        # quarterly growth rate, in percent (ga = 1 + ybar / 100)
    'alpha': 0.19,       # capital share
    # --- estimated shock-process parameters ---
    'rho_i': 0.15,       # persistence of the interest rate shock
    'sig_i': 0.24,       # standard deviation of the interest rate shock
    # --- calibrated parameters ---
    'delta': 0.025,      # depreciation rate
    'g': 0.18,           # government spending as a share of GDP
    'mu_w_ss': 1.5,      # steady-state wage markup
    'curvp': 10,         # curvature of the Kimball aggregator for prices
    'curvw': 10,         # curvature of the Kimball aggregator for wages
}


# General-equilibrium system of the Smets-Wouters model.
unknowns_sw = ['w', 'wf', 'n', 'nf', 'r', 'rf']
targets_sw = [
    'goods_mkt',
    'goods_mkt_f',
    'fisher',
    'w_res',
    'mu_p_f',
    'mu_w_f',
]
# NOTE(review): nkpc_p is defined above but not listed here; inflation is
# presumably supplied by an external Phillips curve fed by pc_inputs_sw.
blocks_sw = [
    household, household_f, household_markup,
    firm, firm_f,
    nkpc_w,
    monetary,
    nominal,
    geneq, geneq_f,
    measurement,
    pc_inputs_sw,
]
ss_sw = par_to_coefs(parameters)

# %% old nk model
#
# @simple
# def standard_nk(c, n, y, x, w, i, p, pi, eis, frisch, alpha, elast, phi_pi, phi_y, eps_i):
#     x_share = (1 - alpha) * (elast - 1) / elast
#     euler_res = c - (c(1) - eis * (i - pi(1)))
#     labor_res = n / frisch + c / eis - w
#     prod_res = y - alpha * n - (1 - alpha) * x
#     cost_res = x - w - n
#     mkt_res = y - x_share * x - (1 - x_share) * c
#     taylor_res = i - (phi_pi * pi + phi_y * c + eps_i)
#     p_res = p - p(-1) - pi
#     MC = alpha * w + (1 - alpha) * p
#     return euler_res, labor_res, prod_res, cost_res, mkt_res, taylor_res, p_res, MC
#
#
# @simple
# def pc_inputs_nk(c, y, w, eis, alpha, elast):
#     discount = -c / eis + w
#     shift = (alpha * (1 - elast)  - 1) * w + y
#     gap = alpha * w
#     return discount, gap, shift
#
#
# unknowns_nk = ['c', 'n', 'y', 'x', 'w', 'i', 'p']
# targets_nk = ['euler_res', 'labor_res', 'prod_res', 'cost_res', 'mkt_res', 'taylor_res', 'p_res']
# blocks_nk = [standard_nk, pc_inputs_nk]
# ss_nk = {'c': 0, 'n': 0, 'y': 0, 'x': 0, 'w': 0, 'i': 0, 'pi': 0, 'p': 0,
#          'eis': 1, 'frisch': 0.2, 'alpha': 1, 'elast': 7, 'phi_pi': 1.5, 'phi_y': 0.125, 'eps_i': 0}
#
